'''
Description: 基于分割的文本检测网络，识别文字框,并进行图片方向矫正
Version: 2.0
Date: 2022-01-27 22:06:29
LastEditors: Meroke
LastEditTime: 2022-05-17 22:50:39
'''


import onnxruntime as rt
import  numpy as np
import time
import cv2
from .decode import  SegDetectorRepresenter
from math import sqrt,asin
import os

# Global preprocessing constants.
# Per-channel normalisation statistics in R, G, B order: DBNET.process converts
# the BGR input to RGB, scales pixels by 1/255, then subtracts `mean` and
# divides by `std`. The values match the commonly used ImageNet mean/std.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

def Singleton(cls):
    """Class decorator that makes every call return one shared instance of *cls*.

    The first successful call constructs the instance with the arguments it
    was given. Every later call returns that same object and ignores its own
    arguments. If construction raises, nothing is cached and the next call
    tries again.
    """
    instance = None
    created = False

    def _singleton(*args, **kwargs):
        nonlocal instance, created
        if not created:
            # Only mark as created once the constructor has returned normally.
            instance = cls(*args, **kwargs)
            created = True
        return instance

    return _singleton


class SingletonType(type):
    """Metaclass whose ``__call__`` runs ``__new__`` and ``__init__`` explicitly.

    NOTE(review): despite the name, no instance is cached. Every ``Cls(...)``
    call allocates and returns a brand-new object; compare the ``Singleton``
    decorator in this module, which does cache. One difference from the default
    ``type.__call__``: ``__init__`` is invoked even when ``__new__`` returns
    an object that is not an instance of the class.
    """

    def __init__(cls, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(cls, *args, **kwargs):
        # Allocate first, then initialise that same object with the same args.
        instance = cls.__new__(cls, *args, **kwargs)
        cls.__init__(instance, *args, **kwargs)
        return instance


def draw_bbox(img_path, result, color=(255, 0, 0), thickness=2):
    """Draw closed polygons on a copy of an image.

    Args:
        img_path: a path to an image file, which is loaded with ``cv2.imread``,
            or an already-loaded image array. A passed-in array is never
            modified.
        result: iterable of polygons, each a sequence of (x, y) points, such as
            the ``box_list`` returned by ``DBNET.process``.
        color: polyline colour in the image's channel order (BGR for images
            loaded with ``cv2.imread``).
        thickness: polyline thickness in pixels.

    Returns:
        A new image array with the polygons drawn on it.

    Raises:
        FileNotFoundError: if ``img_path`` is a path ``cv2.imread`` cannot read.
    """
    if isinstance(img_path, str):
        image = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError("cv2.imread could not read image: {}".format(img_path))
    else:
        image = img_path
    image = image.copy()
    for polygon in result:
        # np.asarray accepts plain lists of points as well as arrays.
        # Coordinates are truncated to integer pixel positions.
        points = np.asarray(polygon).astype(int)
        cv2.polylines(image, [points], True, color, thickness)
    return image

def draw_max_bbox(img_path, max_box_points, color=(255, 0, 0), thickness=2):
    """Return a copy of the image intended for highlighting the largest text box.

    NOTE: all drawing is currently disabled. The function only loads the image
    (when given a path) and returns a copy. ``max_box_points``, ``color`` and
    ``thickness`` are accepted for API compatibility but unused. The disabled
    code (formerly left commented out here) outlined the box with
    ``cv2.polylines``, optionally pushing each corner 10 px outward. It also
    labelled the corners with their indices via ``cv2.putText``.

    Args:
        img_path: a path to an image file, which is loaded with ``cv2.imread``,
            or an already-loaded image array. A passed-in array is never
            modified.
        max_box_points: corner points of the box (currently unused).
        color: drawing colour (currently unused).
        thickness: line thickness (currently unused).

    Returns:
        A new image array with the same content as the input.

    Raises:
        FileNotFoundError: if ``img_path`` is a path ``cv2.imread`` cannot read.
    """
    if isinstance(img_path, str):
        image = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError("cv2.imread could not read image: {}".format(img_path))
    else:
        image = img_path
    return image.copy()

class DBNET(metaclass=SingletonType):
    """DBNet text detector: ONNX Runtime inference plus box decoding.

    NOTE(review): ``SingletonType`` (defined in this module) does not cache
    instances. Every ``DBNET(...)`` call therefore creates a new
    ``InferenceSession``; reuse one object if model loading is expensive.
    """

    def __init__(self, MODEL_PATH):
        """Load the ONNX model and create the output decoder.

        Args:
            MODEL_PATH: path to the DBNet ``.onnx`` model, passed unchanged to
                ``onnxruntime.InferenceSession``.
        """
        self.sess = rt.InferenceSession(MODEL_PATH)
        # Post-processor that turns the network's output map into boxes/scores.
        self.decode_handel = SegDetectorRepresenter()

    def process(self, img, short_size):
        """Detect text boxes in an image.

        Args:
            img: image array in OpenCV BGR channel order (e.g. from
                ``cv2.imread``); it is converted to RGB before inference.
            short_size: target length of the shorter image side. The other
                side is scaled by the same factor and then rounded down to a
                multiple of 32 (minimum 32). ``0`` disables resizing and runs
                at the original resolution.

        Returns:
            ``(box_list, score_list)`` as produced by the decoder, with
            all-zero boxes removed. Two empty lists are returned when nothing
            is detected.
        """
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        if short_size != 0:
            if h < w:
                # Height is the short side: scale it to short_size, then snap
                # the resulting width down to a multiple of 32.
                scale_h = short_size / h
                tar_w = w * scale_h
                tar_w = tar_w - tar_w % 32
                tar_w = max(32, tar_w)
                scale_w = tar_w / w
            else:
                # Width is the short (or equal) side: mirror of the branch above.
                scale_w = short_size / w
                tar_h = h * scale_w
                tar_h = tar_h - tar_h % 32
                tar_h = max(32, tar_h)
                scale_h = tar_h / h
            # The resize lives inside this branch: with short_size == 0 the
            # scale factors were never bound and cv2.resize used to fail with
            # UnboundLocalError.
            img = cv2.resize(img, None, fx=scale_w, fy=scale_h)
        # NOTE(review): with short_size == 0 the sides are not snapped to
        # multiples of 32; confirm the exported model accepts arbitrary sizes.

        # Scale to [0, 1], then normalise per channel with the module-level
        # mean/std (RGB order, matching the conversion above).
        img = img.astype(np.float32)
        img /= 255.0
        img -= mean
        img /= std
        # HWC -> CHW, then add a batch dimension of 1.
        img = img.transpose(2, 0, 1)
        transformed_image = np.expand_dims(img, axis=0)
        out = self.sess.run(["out1"], {"input0": transformed_image.astype(np.float32)})
        # out[0][0] is the first requested output for the single batch item.
        # The decoder also receives the pre-resize (h, w), presumably to map
        # boxes back to input-image coordinates; confirm in decode.py.
        box_list, score_list = self.decode_handel(out[0][0], h, w)
        if len(box_list) > 0:
            # Flatten each box to one row and keep rows whose coordinate sum is
            # positive. Assuming non-negative coordinates, this drops all-zero
            # placeholder boxes.
            idx = box_list.reshape(box_list.shape[0], -1).sum(axis=1) > 0
            box_list, score_list = box_list[idx], score_list[idx]
        else:
            box_list, score_list = [], []
        return box_list, score_list


if __name__ == "__main__":
    import sys

    # Manual smoke test. Optional CLI overrides: [image_path] [model_path].
    # The defaults are the original hard-coded paths.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "/home/meroke/text.png"
    model_path = sys.argv[2] if len(sys.argv) > 2 else "../models/dbnet.onnx"

    text_handle = DBNET(MODEL_PATH=model_path)
    img = cv2.imread(image_path)
    # cv2.imread returns None on failure; fail early with a clear message
    # instead of a cryptic error inside cv2.resize.
    if img is None:
        raise SystemExit("could not read image: {}".format(image_path))
    img = cv2.resize(img, (320, 320))
    print(img.shape)
    box_list, score_list = text_handle.process(img, 320)
    img = draw_bbox(img, box_list)
    cv2.imwrite("test.jpg", img)
